{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n",
      "  \"This module will be removed in 0.20.\", DeprecationWarning)\n"
     ]
    }
   ],
   "source": [
    "import collections\n",
    "import json\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensor2tensor.utils import beam_search, rouge\n",
    "\n",
    "# sklearn.cross_validation was deprecated in 0.18 and removed in 0.20;\n",
    "# prefer model_selection and fall back only on very old installs.\n",
    "try:\n",
    "    from sklearn.model_selection import train_test_split\n",
    "except ImportError:\n",
    "    from sklearn.cross_validation import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Articles and their headlines, read as two JSON documents that the cells below\n",
    "# index in parallel. The encoding is passed explicitly so decoding does not\n",
    "# depend on the platform locale (assumes the files are UTF-8 -- TODO confirm).\n",
    "with open('dataset/ctexts.json', 'r', encoding = 'utf-8') as fopen:\n",
    "    ctexts = json.load(fopen)\n",
    "\n",
    "with open('dataset/headlines.json', 'r', encoding = 'utf-8') as fopen:\n",
    "    headlines = json.load(fopen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "from sklearn.decomposition import TruncatedSVD\n",
    "\n",
    "# Number of top-ranked terms kept per article; also the fixed width of the\n",
    "# encoder input placeholder.\n",
    "max_len = 300\n",
    "\n",
    "def topic_modelling(string, n = 200):\n",
    "    '''Return up to `n` terms of `string`, strongest first, joined by spaces.\n",
    "\n",
    "    The text is TF-IDF vectorised as a one-document corpus and reduced with a\n",
    "    single-component TruncatedSVD; terms are ordered by descending weight in\n",
    "    that component.\n",
    "    '''\n",
    "    tfidf = TfidfVectorizer()\n",
    "    doc_term = tfidf.fit_transform([string])\n",
    "    vocab = tfidf.get_feature_names()\n",
    "    svd = TruncatedSVD(1).fit(doc_term)\n",
    "    # argsort is ascending, so walk it backwards to take the n largest weights.\n",
    "    ranked = svd.components_[0].argsort()[: -n - 1 : -1]\n",
    "    return ' '.join(vocab[idx] for idx in ranked)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/sklearn/decomposition/truncated_svd.py:192: RuntimeWarning: invalid value encountered in true_divide\n",
      "  self.explained_variance_ratio_ = exp_var / full_var\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 18.8 s, sys: 24.1 s, total: 42.9 s\n",
      "Wall time: 10.8 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Build aligned (headline, reduced article) pairs. Articles that the TF-IDF/SVD\n",
    "# step cannot handle are skipped, but counted rather than silently dropped.\n",
    "h, c = [], []\n",
    "skipped = 0\n",
    "for i in range(len(ctexts)):\n",
    "    try:\n",
    "        # Resolve both halves before appending so a failure on either one can\n",
    "        # never leave `c` and `h` with different lengths.\n",
    "        reduced = topic_modelling(ctexts[i], max_len)\n",
    "        headline = headlines[i]\n",
    "    except Exception:\n",
    "        # Exception instead of a bare except keeps KeyboardInterrupt working.\n",
    "        skipped += 1\n",
    "        continue\n",
    "    c.append(reduced)\n",
    "    h.append(headline)\n",
    "if skipped:\n",
    "    print('skipped %d of %d articles' % (skipped, len(ctexts)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_dataset(words, n_words):\n",
    "    count = [['PAD', 0], ['GO', 1], ['EOS', 2], ['UNK', 3]]\n",
    "    count.extend(collections.Counter(words).most_common(n_words))\n",
    "    dictionary = dict()\n",
    "    for word, _ in count:\n",
    "        dictionary[word] = len(dictionary)\n",
    "    data = list()\n",
    "    unk_count = 0\n",
    "    for word in words:\n",
    "        index = dictionary.get(word, 0)\n",
    "        if index == 0:\n",
    "            unk_count += 1\n",
    "        data.append(index)\n",
    "    count[0][1] = unk_count\n",
    "    reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))\n",
    "    return data, count, dictionary, reversed_dictionary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vocab from size: 47609\n",
      "Most common words [('dot', 4394), ('the', 4379), ('comma', 4349), ('to', 4280), ('in', 4268), ('of', 4262)]\n",
      "Sample data [5, 7, 4, 6, 10, 9, 11, 8, 15, 1990] ['the', 'to', 'dot', 'comma', 'and', 'of', 'on', 'in', 'was', 'festival']\n",
      "filtered vocab size: 47613\n",
      "% of vocab used: 100.01%\n"
     ]
    }
   ],
   "source": [
    "# Flatten every reduced article into one token stream and index all of its\n",
    "# distinct tokens.\n",
    "concat = ' '.join(c).split()\n",
    "vocabulary_size = len(set(concat))\n",
    "data, count, dictionary, rev_dictionary = build_dataset(concat, vocabulary_size)\n",
    "\n",
    "vocab_ratio = round(len(dictionary) / vocabulary_size, 4) * 100\n",
    "print('vocab from size: %d' % vocabulary_size)\n",
    "print('Most common words', count[4:10])\n",
    "print('Sample data', data[:10], [rev_dictionary[i] for i in data[:10]])\n",
    "print('filtered vocab size:', len(dictionary))\n",
    "print(\"% of vocab used: {}%\".format(vocab_ratio))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'daman and diu revokes mandatory rakshabandhan in offices order EOS'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Terminate every headline with the EOS marker. The endswith guard keeps the\n",
    "# cell idempotent, so running it again does not stack extra markers.\n",
    "h = [line if line.endswith(' EOS') else line + ' EOS' for line in h]\n",
    "h[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ids of the reserved markers, looked up once for the model and batching code.\n",
    "GO, PAD, EOS, UNK = (dictionary[token] for token in ('GO', 'PAD', 'EOS', 'UNK'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def str_idx(corpus, dic, UNK=3):\n",
    "    X = []\n",
    "    for i in corpus:\n",
    "        ints = []\n",
    "        for k in i.split():\n",
    "            ints.append(dic.get(k, UNK))\n",
    "        X.append(ints)\n",
    "    return X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode both sides with the same dictionary. That dictionary was built from\n",
    "# the reduced articles only, so any headline token missing from it becomes UNK.\n",
    "X = str_idx(c, dictionary)\n",
    "Y = str_idx(h, dictionary)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed seed so the held-out 5% is the same on every Restart & Run All.\n",
    "train_X, test_X, train_Y, test_Y = train_test_split(X, Y, test_size = 0.05, random_state = 42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://raw.githubusercontent.com/policeme/transformer-pointer-generator/master/modules.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modules import get_token_embeddings, ff, positional_encoding, multihead_attention, noam_scheme"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def position_encoding(inputs):\n",
    "    '''Return sinusoidal position encodings with the same shape as `inputs`.\n",
    "\n",
    "    `inputs` is indexed as (batch, time, depth) and its depth must be known\n",
    "    statically. For every time step the first half of the depth axis holds the\n",
    "    sines and the second half the cosines of position / 10000 ** (i / depth)\n",
    "    for i = 0, 2, 4, ...\n",
    "    '''\n",
    "    batch = tf.shape(inputs)[0]\n",
    "    steps = tf.shape(inputs)[1]\n",
    "    depth = inputs.get_shape()[-1].value\n",
    "    positions = tf.reshape(tf.range(0.0, tf.to_float(steps), dtype=tf.float32), [-1, 1])\n",
    "    even_dims = np.arange(0, depth, 2, np.float32)\n",
    "    scales = np.reshape(np.power(10000.0, even_dims / depth), [1, -1])\n",
    "    angles = positions / scales\n",
    "    table = tf.concat([tf.sin(angles), tf.cos(angles)], 1)\n",
    "    # One table for the whole batch, repeated along the batch axis.\n",
    "    return tf.tile(tf.expand_dims(table, 0), [batch, 1, 1])\n",
    "\n",
    "class Summarization:\n",
    "    '''Transformer encoder-decoder with a pointer-generator output (TF1 graph mode).\n",
    "\n",
    "    Constructing an instance builds the whole training graph:\n",
    "      X, Y             int32 placeholders for source ids (width max_len) and target ids\n",
    "      training_logits  per-step word distribution returned by forward(); despite\n",
    "                       the name these are probabilities, not unnormalised scores\n",
    "      cost             masked negative log-likelihood plus the coverage penalty\n",
    "      optimizer        Adam step minimising cost\n",
    "      prediction, accuracy  greedy token predictions and their accuracy on unpadded steps\n",
    "\n",
    "    Id 0 is treated as padding when sequence lengths are derived. The output\n",
    "    projection reuses the transposed embedding matrix, so size_layer is presumably\n",
    "    expected to equal embedded_size.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, size_layer, num_layers, embedded_size,\n",
    "                 dict_size, learning_rate,\n",
    "                 kernel_size = 2, n_attn_heads = 16):\n",
    "        '''Create the placeholders, run forward() and attach loss, optimizer and metrics.\n",
    "\n",
    "        kernel_size is stored on the instance but not used by the graph built here.\n",
    "        '''\n",
    "        self.X = tf.placeholder(tf.int32, [None, max_len])\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "\n",
    "        # Sequence lengths are the number of non-zero ids per row.\n",
    "        self.X_seq_len = tf.count_nonzero(self.X, 1, dtype = tf.int32)\n",
    "        self.Y_seq_len = tf.count_nonzero(self.Y, 1, dtype = tf.int32)\n",
    "        batch_size = tf.shape(self.X)[0]\n",
    "        self.batch_size = batch_size\n",
    "        # Teacher forcing: the decoder input is GO followed by Y without its last column.\n",
    "        main = tf.strided_slice(self.Y, [0, 0], [batch_size, -1], [1, 1])\n",
    "        decoder_input = tf.concat([tf.fill([batch_size, 1], GO), main], 1)\n",
    "\n",
    "        self.embedding = tf.Variable(tf.random_uniform([dict_size, embedded_size], -1, 1))\n",
    "\n",
    "        self.num_layers = num_layers\n",
    "        self.kernel_size = kernel_size\n",
    "        self.size_layer = size_layer\n",
    "        self.n_attn_heads = n_attn_heads\n",
    "        self.dict_size = dict_size\n",
    "\n",
    "        self.training_logits, coverage_loss = self.forward(self.X, decoder_input)\n",
    "        maxlen = tf.reduce_max(self.Y_seq_len)\n",
    "        masks = tf.sequence_mask(self.Y_seq_len, maxlen, dtype=tf.float32)\n",
    "\n",
    "        # Gather the probability the model assigns to every gold token.\n",
    "        targets = tf.slice(self.Y, [0, 0], [-1, maxlen])\n",
    "        i1, i2 = tf.meshgrid(tf.range(batch_size),\n",
    "                             tf.range(maxlen), indexing='ij')\n",
    "        indices = tf.stack((i1, i2, targets), axis=2)\n",
    "        probs = tf.gather_nd(self.training_logits, indices)\n",
    "        # Clamp non-positive probabilities so the log below stays finite.\n",
    "        probs = tf.where(tf.less_equal(probs, 0), tf.ones_like(probs) * 1e-10, probs)\n",
    "        crossent = -tf.log(probs)\n",
    "        self.cost = tf.reduce_sum(crossent * masks) / tf.to_float(batch_size)\n",
    "        self.coverage_loss = tf.reduce_sum(coverage_loss / tf.to_float(batch_size))\n",
    "        self.cost = self.cost + self.coverage_loss\n",
    "\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(self.cost)\n",
    "        y_t = tf.argmax(self.training_logits, axis=2)\n",
    "        y_t = tf.cast(y_t, tf.int32)\n",
    "        self.prediction = tf.boolean_mask(y_t, masks)\n",
    "        mask_label = tf.boolean_mask(self.Y, masks)\n",
    "        correct_pred = tf.equal(self.prediction, mask_label)\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))\n",
    "\n",
    "    def _calc_final_dist(self, x, gens, vocab_dists, attn_dists):\n",
    "        '''Mix the generator and copy distributions into one vocabulary distribution.\n",
    "\n",
    "        x            source ids, [batch, attn_len]\n",
    "        gens         generation probability per decoder step, broadcast over the last axis\n",
    "        vocab_dists  softmax over the vocabulary, [batch, dec, dict_size]\n",
    "        attn_dists   attention over source positions, [batch, dec, attn_len]\n",
    "\n",
    "        Returns gens * vocab_dists plus (1 - gens) * attn_dists scattered onto the\n",
    "        vocabulary ids of the attended source tokens, [batch, dec, dict_size].\n",
    "        '''\n",
    "        with tf.variable_scope('final_distribution', reuse=tf.AUTO_REUSE):\n",
    "            vocab_dists = gens * vocab_dists\n",
    "            attn_dists = (1-gens) * attn_dists\n",
    "            batch_size = tf.shape(attn_dists)[0]\n",
    "            dec_t = tf.shape(attn_dists)[1]\n",
    "            attn_len = tf.shape(attn_dists)[2]\n",
    "\n",
    "            # Build, per example, (decoder step, source token id) index pairs.\n",
    "            dec = tf.range(0, limit=dec_t) # [dec]\n",
    "            dec = tf.expand_dims(dec, axis=-1) # [dec, 1]\n",
    "            dec = tf.tile(dec, [1, attn_len]) # [dec, atten_len]\n",
    "            dec = tf.expand_dims(dec, axis=0) # [1, dec, atten_len]\n",
    "            dec = tf.tile(dec, [batch_size, 1, 1]) # [batch_size, dec, atten_len]\n",
    "\n",
    "            x = tf.expand_dims(x, axis=1) # [batch_size, 1, atten_len]\n",
    "            x = tf.tile(x, [1, dec_t, 1]) # [batch_size, dec, atten_len]\n",
    "            x = tf.stack([dec, x], axis=3)\n",
    "\n",
    "            # scatter_nd adds up updates that share an index, so a source token that\n",
    "            # occurs several times contributes the sum of its attention weights.\n",
    "            attn_dists_projected = tf.map_fn(\n",
    "                fn=lambda y: tf.scatter_nd(y[0], y[1], [dec_t, self.dict_size]),\n",
    "                elems=(x, attn_dists), dtype=tf.float32)\n",
    "\n",
    "            final_dists = attn_dists_projected + vocab_dists\n",
    "            return final_dists\n",
    "\n",
    "    def forward(self, x, y, reuse = False):\n",
    "        '''Run encoder and decoder and return (final_dists, coverage_loss).\n",
    "\n",
    "        x      source ids, [batch, enc_steps]\n",
    "        y      decoder input ids starting with GO, [batch, dec_steps]\n",
    "        reuse  passed to the enclosing variable scopes so a second call shares weights\n",
    "\n",
    "        final_dists is the pointer-generator word distribution,\n",
    "        [batch, dec_steps, dict_size]. coverage_loss holds the unreduced coverage\n",
    "        terms, [batch, enc_steps, dec_steps]; the caller sums them.\n",
    "        '''\n",
    "        # The doubled scope is kept so variable names stay forward/forward/...\n",
    "        with tf.variable_scope('forward',reuse=reuse):\n",
    "            with tf.variable_scope('forward',reuse=reuse):\n",
    "                encoder_embedded = tf.nn.embedding_lookup(self.embedding, x)\n",
    "                decoder_embedded = tf.nn.embedding_lookup(self.embedding, y)\n",
    "\n",
    "                encoder_embedded += position_encoding(encoder_embedded)\n",
    "                decoder_embedded += position_encoding(decoder_embedded)\n",
    "                enc = encoder_embedded\n",
    "\n",
    "                # NOTE(review): unlike the decoder blocks below, these layers are not\n",
    "                # wrapped in a per-layer variable_scope. If the helpers from modules.py\n",
    "                # open their scopes with AUTO_REUSE, every encoder layer shares one set\n",
    "                # of weights -- confirm that this is intended.\n",
    "                for i in range(self.num_layers):\n",
    "                    enc, _ = multihead_attention(queries=enc,\n",
    "                                                 keys=enc,\n",
    "                                                 values=enc,\n",
    "                                                 num_heads=self.n_attn_heads,\n",
    "                                                 dropout_rate=0.0,\n",
    "                                                 training=True,\n",
    "                                                 causality=False)\n",
    "                    enc = ff(enc, num_units=[self.size_layer * 2, self.size_layer])\n",
    "\n",
    "                memory = enc\n",
    "                dec = decoder_embedded\n",
    "\n",
    "                attn_dists = []\n",
    "                for i in range(self.num_layers):\n",
    "                    with tf.variable_scope('num_blocks_{}'.format(i), reuse=tf.AUTO_REUSE):\n",
    "                        # Causal self-attention over the tokens decoded so far.\n",
    "                        dec, _ = multihead_attention(queries=dec,\n",
    "                                                     keys=dec,\n",
    "                                                     values=dec,\n",
    "                                                     num_heads=self.n_attn_heads,\n",
    "                                                     dropout_rate=0.0,\n",
    "                                                     training=True,\n",
    "                                                     causality=True,\n",
    "                                                     scope='self_attention')\n",
    "\n",
    "                        # Cross-attention over the encoder output; its weights feed the\n",
    "                        # copy distribution and the coverage penalty.\n",
    "                        dec, attn_dist = multihead_attention(queries=dec,\n",
    "                                                             keys=memory,\n",
    "                                                             values=memory,\n",
    "                                                             num_heads=self.n_attn_heads,\n",
    "                                                             dropout_rate=0.0,\n",
    "                                                             training=True,\n",
    "                                                             causality=False,\n",
    "                                                             scope='vanilla_attention')\n",
    "                        attn_dists.append(attn_dist)\n",
    "                        dec = ff(dec, num_units=[self.size_layer * 2, self.size_layer])\n",
    "\n",
    "                # Tie the output projection to the input embedding.\n",
    "                weights = tf.transpose(self.embedding)\n",
    "                logits = tf.einsum('ntd,dk->ntk', dec, weights)\n",
    "\n",
    "                with tf.variable_scope('gen', reuse=tf.AUTO_REUSE):\n",
    "                    gens = tf.layers.dense(tf.concat([decoder_embedded, dec, attn_dists[-1]], axis=-1),\n",
    "                                           units=1, activation=tf.sigmoid, use_bias=False)\n",
    "\n",
    "                vocab_dists = tf.nn.softmax(logits)\n",
    "\n",
    "                # Coverage penalty: min(attention at step t, attention accumulated over\n",
    "                # steps < t). The attention is [batch, dec, enc], so move the decoder\n",
    "                # axis last and accumulate along it; the previous [1, 2, 0] permutation\n",
    "                # put the batch axis last and accumulated across examples instead.\n",
    "                alignment_history = tf.transpose(attn_dists[-1], [0, 2, 1])\n",
    "                coverage_loss = tf.minimum(alignment_history,\n",
    "                                           tf.cumsum(alignment_history, axis=2, exclusive=True))\n",
    "\n",
    "                return self._calc_final_dist(x, gens, vocab_dists, attn_dists[-1]), coverage_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model and training settings. size_layer, num_layers, embedded_size and\n",
    "# learning_rate are passed to Summarization; batch_size is the minibatch size\n",
    "# used by the training cell.\n",
    "size_layer = 256\n",
    "num_layers = 4\n",
    "embedded_size = 256\n",
    "learning_rate = 1e-3\n",
    "batch_size = 16\n",
    "# NOTE(review): `epoch` is not read by the training cell, which hard-codes range(10).\n",
    "epoch = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def beam_search_decoding(length = 20, beam_width = 5):\n",
    "    '''Build the beam-search decoding op for the global `model`.\n",
    "\n",
    "    length      maximum number of decoding steps\n",
    "    beam_width  number of hypotheses kept per input row\n",
    "\n",
    "    Returns the decoded-ids tensor produced by tensor2tensor's beam_search;\n",
    "    the beam scores are discarded.\n",
    "    '''\n",
    "    initial_ids = tf.fill([model.batch_size], GO)\n",
    "\n",
    "    def symbols_to_logits(ids):\n",
    "        # Repeat every source row beam_width times so it lines up with `ids`,\n",
    "        # which carries batch and beam flattened into one axis.\n",
    "        x = tf.contrib.seq2seq.tile_batch(model.X, beam_width)\n",
    "        probs, _ = model.forward(x, ids, reuse = True)\n",
    "        last_step = probs[:, tf.shape(ids)[1]-1, :]\n",
    "        # forward() yields probabilities, while beam_search applies a log-softmax\n",
    "        # to whatever it receives. Handing over log-probabilities keeps the model\n",
    "        # distribution intact instead of flattening it towards uniform. The clamp\n",
    "        # mirrors the one used for the training loss.\n",
    "        return tf.log(tf.maximum(last_step, 1e-10))\n",
    "\n",
    "    final_ids, _ = beam_search.beam_search(\n",
    "        symbols_to_logits,\n",
    "        initial_ids,\n",
    "        beam_width,\n",
    "        length,\n",
    "        len(dictionary),\n",
    "        0.0,\n",
    "        eos_id = EOS)\n",
    "\n",
    "    return final_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(\"forward/forward/num_blocks_0/vanilla_attention/scaled_dot_product_attention/Reshape_1:0\", shape=(?, ?, 300), dtype=float32)\n",
      "Tensor(\"forward/forward/num_blocks_1/vanilla_attention/scaled_dot_product_attention/Reshape_1:0\", shape=(?, ?, 300), dtype=float32)\n",
      "Tensor(\"forward/forward/num_blocks_2/vanilla_attention/scaled_dot_product_attention/Reshape_1:0\", shape=(?, ?, 300), dtype=float32)\n",
      "Tensor(\"forward/forward/num_blocks_3/vanilla_attention/scaled_dot_product_attention/Reshape_1:0\", shape=(?, ?, 300), dtype=float32)\n",
      "Tensor(\"forward/forward/add_1:0\", shape=(?, ?, 256), dtype=float32) Tensor(\"forward/forward/num_blocks_3/positionwise_feedforward/ln/add_1:0\", shape=(?, ?, 256), dtype=float32) Tensor(\"forward/forward/num_blocks_3/vanilla_attention/scaled_dot_product_attention/Reshape_1:0\", shape=(?, ?, 300), dtype=float32)\n",
      "Tensor(\"forward/forward/gen/dense/Sigmoid:0\", shape=(?, ?, 1), dtype=float32)\n",
      "Tensor(\"while/forward/forward/num_blocks_0/vanilla_attention/scaled_dot_product_attention/Reshape_1:0\", shape=(?, ?, 300), dtype=float32)\n",
      "Tensor(\"while/forward/forward/num_blocks_1/vanilla_attention/scaled_dot_product_attention/Reshape_1:0\", shape=(?, ?, 300), dtype=float32)\n",
      "Tensor(\"while/forward/forward/num_blocks_2/vanilla_attention/scaled_dot_product_attention/Reshape_1:0\", shape=(?, ?, 300), dtype=float32)\n",
      "Tensor(\"while/forward/forward/num_blocks_3/vanilla_attention/scaled_dot_product_attention/Reshape_1:0\", shape=(?, ?, 300), dtype=float32)\n",
      "Tensor(\"while/forward/forward/add_1:0\", shape=(?, ?, 256), dtype=float32) Tensor(\"while/forward/forward/num_blocks_3/positionwise_feedforward/ln/add_1:0\", shape=(?, ?, 256), dtype=float32) Tensor(\"while/forward/forward/num_blocks_3/vanilla_attention/scaled_dot_product_attention/Reshape_1:0\", shape=(?, ?, 300), dtype=float32)\n",
      "Tensor(\"while/forward/forward/gen/dense/Sigmoid:0\", shape=(?, ?, 1), dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "# Start from an empty graph so re-running this cell does not pile up duplicate ops.\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Summarization(size_layer, num_layers, embedded_size, \n",
    "                      len(dictionary), learning_rate)\n",
    "# beam_search_decoding reads the global `model`, so it has to run after the model\n",
    "# exists; it calls forward(reuse = True) and therefore shares the training weights.\n",
    "model.generate = beam_search_decoding()\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_sentence_batch(sentence_batch, pad_int, maxlen = None):\n",
    "    padded_seqs = []\n",
    "    seq_lens = []\n",
    "    max_sentence_len = max([len(sentence) for sentence in sentence_batch])\n",
    "    if maxlen:\n",
    "        max_sentence_len = maxlen\n",
    "    for sentence in sentence_batch:\n",
    "        padded_seqs.append(sentence + [pad_int] * (max_sentence_len - len(sentence)))\n",
    "        seq_lens.append(len(sentence))\n",
    "    return padded_seqs, seq_lens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 261/261 [02:05<00:00,  2.14it/s, accuracy=0.122, cost=72.7, rouge_2=0]       \n",
      "test minibatch loop: 100%|██████████| 14/14 [00:02<00:00,  5.80it/s, accuracy=0.0963, cost=78.4, rouge_2=0.00926]\n",
      "train minibatch loop:   0%|          | 0/261 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, avg loss: 92.744266, avg accuracy: 0.080139\n",
      "epoch: 0, avg loss test: 77.564087, avg accuracy test: 0.101171\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 261/261 [02:04<00:00,  2.13it/s, accuracy=0.0899, cost=73.9, rouge_2=0]      \n",
      "test minibatch loop: 100%|██████████| 14/14 [00:02<00:00,  5.99it/s, accuracy=0.0815, cost=69, rouge_2=0.0104]   \n",
      "train minibatch loop:   0%|          | 0/261 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1, avg loss: 70.117131, avg accuracy: 0.101949\n",
      "epoch: 1, avg loss test: 71.875333, avg accuracy test: 0.106957\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 261/261 [02:04<00:00,  2.15it/s, accuracy=0.0983, cost=68.6, rouge_2=0]      \n",
      "test minibatch loop: 100%|██████████| 14/14 [00:02<00:00,  5.85it/s, accuracy=0.13, cost=78.7, rouge_2=0.00617]  \n",
      "train minibatch loop:   0%|          | 0/261 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 2, avg loss: 65.772161, avg accuracy: 0.117132\n",
      "epoch: 2, avg loss test: 70.763171, avg accuracy test: 0.103186\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 261/261 [02:04<00:00,  2.15it/s, accuracy=0.125, cost=63.8, rouge_2=0.00667]\n",
      "test minibatch loop: 100%|██████████| 14/14 [00:02<00:00,  5.92it/s, accuracy=0.112, cost=70.3, rouge_2=0.0104] \n",
      "train minibatch loop:   0%|          | 0/261 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 3, avg loss: 63.511224, avg accuracy: 0.126839\n",
      "epoch: 3, avg loss test: 70.971090, avg accuracy test: 0.107998\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 261/261 [02:04<00:00,  2.14it/s, accuracy=0.115, cost=56.1, rouge_2=0.0239]  \n",
      "test minibatch loop: 100%|██████████| 14/14 [00:02<00:00,  5.74it/s, accuracy=0.133, cost=69.3, rouge_2=0]      \n",
      "train minibatch loop:   0%|          | 0/261 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 4, avg loss: 60.778111, avg accuracy: 0.134995\n",
      "epoch: 4, avg loss test: 72.008554, avg accuracy test: 0.117280\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 261/261 [02:04<00:00,  2.13it/s, accuracy=0.115, cost=54.8, rouge_2=0.00833] \n",
      "test minibatch loop: 100%|██████████| 14/14 [00:02<00:00,  5.99it/s, accuracy=0.111, cost=70.7, rouge_2=0.00794]\n",
      "train minibatch loop:   0%|          | 0/261 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 5, avg loss: 57.611743, avg accuracy: 0.145993\n",
      "epoch: 5, avg loss test: 75.203302, avg accuracy test: 0.106658\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 261/261 [02:04<00:00,  2.13it/s, accuracy=0.15, cost=53.8, rouge_2=0.0333]  \n",
      "test minibatch loop: 100%|██████████| 14/14 [00:02<00:00,  5.76it/s, accuracy=0.122, cost=68.9, rouge_2=0.0111]\n",
      "train minibatch loop:   0%|          | 0/261 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 6, avg loss: 56.036153, avg accuracy: 0.146299\n",
      "epoch: 6, avg loss test: 75.403192, avg accuracy test: 0.108965\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 261/261 [02:04<00:00,  2.15it/s, accuracy=0.165, cost=52.6, rouge_2=0.0424] \n",
      "test minibatch loop: 100%|██████████| 14/14 [00:02<00:00,  5.98it/s, accuracy=0.0952, cost=67.7, rouge_2=0]     \n",
      "train minibatch loop:   0%|          | 0/261 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 7, avg loss: 52.667047, avg accuracy: 0.161405\n",
      "epoch: 7, avg loss test: 76.735556, avg accuracy test: 0.116259\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 261/261 [02:04<00:00,  2.13it/s, accuracy=0.175, cost=46.7, rouge_2=0.0162] \n",
      "test minibatch loop: 100%|██████████| 14/14 [00:02<00:00,  5.96it/s, accuracy=0.096, cost=82.3, rouge_2=0]       \n",
      "train minibatch loop:   0%|          | 0/261 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 8, avg loss: 50.593944, avg accuracy: 0.167136\n",
      "epoch: 8, avg loss test: 84.792560, avg accuracy test: 0.101978\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 261/261 [02:04<00:00,  2.12it/s, accuracy=0.0788, cost=55.4, rouge_2=0.00889]\n",
      "test minibatch loop: 100%|██████████| 14/14 [00:02<00:00,  5.96it/s, accuracy=0.0667, cost=79.8, rouge_2=0.0284] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 9, avg loss: 49.530213, avg accuracy: 0.167288\n",
      "epoch: 9, avg loss test: 80.416023, avg accuracy test: 0.043416\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "from sklearn.utils import shuffle\n",
    "import time\n",
    "\n",
    "def run_minibatches(inputs, targets, desc, train = False):\n",
    "    '''Run one pass over (inputs, targets) in minibatches.\n",
    "\n",
    "    With train = True the optimizer op is run as well. Returns (mean loss,\n",
    "    mean accuracy), averaged over the number of minibatches actually executed.\n",
    "    '''\n",
    "    fetches = [model.training_logits, model.accuracy, model.cost]\n",
    "    if train:\n",
    "        fetches.append(model.optimizer)\n",
    "    loss_sum, accuracy_sum, n_batches = 0, 0, 0\n",
    "    pbar = tqdm(range(0, len(inputs), batch_size), desc = desc)\n",
    "    for k in pbar:\n",
    "        batch_x, _ = pad_sentence_batch(inputs[k: k + batch_size], PAD, maxlen = max_len)\n",
    "        batch_y, _ = pad_sentence_batch(targets[k: k + batch_size], PAD)\n",
    "        l, acc, loss = sess.run(fetches,\n",
    "                                feed_dict = {model.X: batch_x,\n",
    "                                             model.Y: batch_y})[:3]\n",
    "        loss_sum += loss\n",
    "        accuracy_sum += acc\n",
    "        n_batches += 1\n",
    "        r = rouge.rouge_n(np.argmax(l, axis = 2), batch_y)\n",
    "        pbar.set_postfix(cost = loss, accuracy = acc, rouge_2 = r)\n",
    "    # Each step reports a per-batch mean, so average over the batch count;\n",
    "    # max(..., 1) avoids a division by zero on an empty split.\n",
    "    n_batches = max(n_batches, 1)\n",
    "    return loss_sum / n_batches, accuracy_sum / n_batches\n",
    "\n",
    "for EPOCH in range(10):\n",
    "    train_X, train_Y = shuffle(train_X, train_Y)\n",
    "    test_X, test_Y = shuffle(test_X, test_Y)\n",
    "    total_loss, total_accuracy = run_minibatches(\n",
    "        train_X, train_Y, 'train minibatch loop', train = True)\n",
    "    total_loss_test, total_accuracy_test = run_minibatches(\n",
    "        test_X, test_Y, 'test minibatch loop')\n",
    "\n",
    "    print('epoch: %d, avg loss: %f, avg accuracy: %f'%(EPOCH, total_loss, total_accuracy))\n",
    "    print('epoch: %d, avg loss test: %f, avg accuracy test: %f'%(EPOCH, total_loss_test, total_accuracy_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-example batch taken from the test split, padded to the encoder input width.\n",
    "b_x, _ = pad_sentence_batch([test_X[0]], PAD, maxlen = max_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5, 4)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The decoding graph only reads model.X, so no targets are fed. [0] keeps the\n",
    "# result for the single input row.\n",
    "generated = sess.run(model.generate, feed_dict = {model.X: b_x})[0]\n",
    "generated.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[   1,    2,    0,    0],\n",
       "       [   1, 1208,    2,    0],\n",
       "       [   0,    0,    0,    0],\n",
       "       [   0,    0,    0,    0],\n",
       "       [   0,    0,    0,    0]], dtype=int32)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "generated"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
